{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tree Learning – implementation and application of decision trees"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook gives you the opportunity to implement some key components of decision tree learning and run your algorithm on a benchmark dataset. Some restrictions will be made to simplify the problem. The notebook concludes by asking you to run the decision tree learner (and the tree-based \"Random Forests\" method) from scikit-learn for comparison.\n",
    "\n",
    "Make sure you have the Titanic dataset (\"```titanic.csv```\") in the directory from which you are running the notebook before you start."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from matplotlib import pyplot as plt\n",
    "%matplotlib inline\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 891 entries, 0 to 890\n",
      "Data columns (total 12 columns):\n",
      "PassengerId    891 non-null int64\n",
      "Survived       891 non-null int64\n",
      "Pclass         891 non-null int64\n",
      "Name           891 non-null object\n",
      "Sex            891 non-null object\n",
      "Age            714 non-null float64\n",
      "SibSp          891 non-null int64\n",
      "Parch          891 non-null int64\n",
      "Ticket         891 non-null object\n",
      "Fare           891 non-null float64\n",
      "Cabin          204 non-null object\n",
      "Embarked       889 non-null object\n",
      "dtypes: float64(2), int64(5), object(5)\n",
      "memory usage: 83.6+ KB\n"
     ]
    }
   ],
   "source": [
    "ds = pd.read_csv('titanic.csv')\n",
    "ds.info()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Preprocessing "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To simplify things we will focus on the supplied dataset and start with some preprocessing: selecting features, converting categorical data to numeric, and dropping rows with missing values. Spend about 10 minutes going through this if you have any doubts. We start by inspecting the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>PassengerId</th>\n",
       "      <th>Survived</th>\n",
       "      <th>Pclass</th>\n",
       "      <th>Name</th>\n",
       "      <th>Sex</th>\n",
       "      <th>Age</th>\n",
       "      <th>SibSp</th>\n",
       "      <th>Parch</th>\n",
       "      <th>Ticket</th>\n",
       "      <th>Fare</th>\n",
       "      <th>Cabin</th>\n",
       "      <th>Embarked</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>Braund, Mr. Owen Harris</td>\n",
       "      <td>male</td>\n",
       "      <td>22.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>A/5 21171</td>\n",
       "      <td>7.2500</td>\n",
       "      <td>NaN</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>Cumings, Mrs. John Bradley (Florence Briggs Th...</td>\n",
       "      <td>female</td>\n",
       "      <td>38.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>PC 17599</td>\n",
       "      <td>71.2833</td>\n",
       "      <td>C85</td>\n",
       "      <td>C</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>Heikkinen, Miss. Laina</td>\n",
       "      <td>female</td>\n",
       "      <td>26.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>STON/O2. 3101282</td>\n",
       "      <td>7.9250</td>\n",
       "      <td>NaN</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>Futrelle, Mrs. Jacques Heath (Lily May Peel)</td>\n",
       "      <td>female</td>\n",
       "      <td>35.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>113803</td>\n",
       "      <td>53.1000</td>\n",
       "      <td>C123</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>Allen, Mr. William Henry</td>\n",
       "      <td>male</td>\n",
       "      <td>35.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>373450</td>\n",
       "      <td>8.0500</td>\n",
       "      <td>NaN</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   PassengerId  Survived  Pclass  \\\n",
       "0            1         0       3   \n",
       "1            2         1       1   \n",
       "2            3         1       3   \n",
       "3            4         1       1   \n",
       "4            5         0       3   \n",
       "\n",
       "                                                Name     Sex   Age  SibSp  \\\n",
       "0                            Braund, Mr. Owen Harris    male  22.0      1   \n",
       "1  Cumings, Mrs. John Bradley (Florence Briggs Th...  female  38.0      1   \n",
       "2                             Heikkinen, Miss. Laina  female  26.0      0   \n",
       "3       Futrelle, Mrs. Jacques Heath (Lily May Peel)  female  35.0      1   \n",
       "4                           Allen, Mr. William Henry    male  35.0      0   \n",
       "\n",
       "   Parch            Ticket     Fare Cabin Embarked  \n",
       "0      0         A/5 21171   7.2500   NaN        S  \n",
       "1      0          PC 17599  71.2833   C85        C  \n",
       "2      0  STON/O2. 3101282   7.9250   NaN        S  \n",
       "3      0            113803  53.1000  C123        S  \n",
       "4      0            373450   8.0500   NaN        S  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ds.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
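    "Do we need all the features? No. We drop the identifier-like columns `PassengerId`, `Name` and `Ticket`, along with `Cabin` (mostly missing: only 204 of 891 values are present) and `Embarked`."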
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Survived</th>\n",
       "      <th>Pclass</th>\n",
       "      <th>Sex</th>\n",
       "      <th>Age</th>\n",
       "      <th>SibSp</th>\n",
       "      <th>Parch</th>\n",
       "      <th>Fare</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>male</td>\n",
       "      <td>22.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>7.2500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>female</td>\n",
       "      <td>38.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>71.2833</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>female</td>\n",
       "      <td>26.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>7.9250</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>female</td>\n",
       "      <td>35.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>53.1000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>male</td>\n",
       "      <td>35.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>8.0500</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Survived  Pclass     Sex   Age  SibSp  Parch     Fare\n",
       "0         0       3    male  22.0      1      0   7.2500\n",
       "1         1       1  female  38.0      1      0  71.2833\n",
       "2         1       3  female  26.0      0      0   7.9250\n",
       "3         1       1  female  35.0      1      0  53.1000\n",
       "4         0       3    male  35.0      0      0   8.0500"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cols_to_drop = [\n",
    "    'PassengerId',\n",
    "    'Name',\n",
    "    'Ticket',\n",
    "    'Cabin',\n",
    "    'Embarked',\n",
    "]\n",
    "\n",
    "df = ds.drop(cols_to_drop, axis=1)\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Another simplification is to treat all attributes as numeric, so we need to convert any that are not. Among the remaining columns, only `Sex` is non-numeric."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Survived</th>\n",
       "      <th>Pclass</th>\n",
       "      <th>Sex</th>\n",
       "      <th>Age</th>\n",
       "      <th>SibSp</th>\n",
       "      <th>Parch</th>\n",
       "      <th>Fare</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>22.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>7.2500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>38.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>71.2833</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>26.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>7.9250</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>35.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>53.1000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>35.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>8.0500</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Survived  Pclass  Sex   Age  SibSp  Parch     Fare\n",
       "0         0       3    0  22.0      1      0   7.2500\n",
       "1         1       1    1  38.0      1      0  71.2833\n",
       "2         1       3    1  26.0      0      0   7.9250\n",
       "3         1       1    1  35.0      1      0  53.1000\n",
       "4         0       3    0  35.0      0      0   8.0500"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def convert_sex_to_num(s):\n",
    "    if s=='male':\n",
    "        return 0\n",
    "    elif s=='female':\n",
    "        return 1\n",
    "    else:\n",
    "        return s\n",
    "\n",
    "df.Sex = df.Sex.map(convert_sex_to_num)\n",
    "df.head()"
   ]
  },
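  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same conversion can also be written by passing a dictionary to `Series.map`. The sketch below uses a small hypothetical frame (`toy`), not the Titanic data. One difference to keep in mind: with a dictionary, values that are not keys become `NaN`, whereas `convert_sex_to_num` returns them unchanged.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Hypothetical toy data, for illustration only\n",
    "toy = pd.DataFrame({'Sex': ['male', 'female', 'female', 'male']})\n",
    "\n",
    "# Map each category to an integer code (same coding as in the cell above)\n",
    "sex_codes = {'male': 0, 'female': 1}\n",
    "toy['Sex'] = toy['Sex'].map(sex_codes)\n",
    "\n",
    "print(toy['Sex'].tolist())\n",
    "```"
   ]
  },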
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's get an overview of the preprocessed dataset with some standard commands. Rows with missing values are dropped first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Survived</th>\n",
       "      <th>Pclass</th>\n",
       "      <th>Sex</th>\n",
       "      <th>Age</th>\n",
       "      <th>SibSp</th>\n",
       "      <th>Parch</th>\n",
       "      <th>Fare</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>714.000000</td>\n",
       "      <td>714.000000</td>\n",
       "      <td>714.000000</td>\n",
       "      <td>714.000000</td>\n",
       "      <td>714.000000</td>\n",
       "      <td>714.000000</td>\n",
       "      <td>714.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>0.406162</td>\n",
       "      <td>2.236695</td>\n",
       "      <td>0.365546</td>\n",
       "      <td>29.699118</td>\n",
       "      <td>0.512605</td>\n",
       "      <td>0.431373</td>\n",
       "      <td>34.694514</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>0.491460</td>\n",
       "      <td>0.838250</td>\n",
       "      <td>0.481921</td>\n",
       "      <td>14.526497</td>\n",
       "      <td>0.929783</td>\n",
       "      <td>0.853289</td>\n",
       "      <td>52.918930</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.420000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>20.125000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>8.050000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>2.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>28.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>15.741700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>3.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>38.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>33.375000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>3.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>80.000000</td>\n",
       "      <td>5.000000</td>\n",
       "      <td>6.000000</td>\n",
       "      <td>512.329200</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         Survived      Pclass         Sex         Age       SibSp       Parch  \\\n",
       "count  714.000000  714.000000  714.000000  714.000000  714.000000  714.000000   \n",
       "mean     0.406162    2.236695    0.365546   29.699118    0.512605    0.431373   \n",
       "std      0.491460    0.838250    0.481921   14.526497    0.929783    0.853289   \n",
       "min      0.000000    1.000000    0.000000    0.420000    0.000000    0.000000   \n",
       "25%      0.000000    1.000000    0.000000   20.125000    0.000000    0.000000   \n",
       "50%      0.000000    2.000000    0.000000   28.000000    0.000000    0.000000   \n",
       "75%      1.000000    3.000000    1.000000   38.000000    1.000000    1.000000   \n",
       "max      1.000000    3.000000    1.000000   80.000000    5.000000    6.000000   \n",
       "\n",
       "             Fare  \n",
       "count  714.000000  \n",
       "mean    34.694514  \n",
       "std     52.918930  \n",
       "min      0.000000  \n",
       "25%      8.050000  \n",
       "50%     15.741700  \n",
       "75%     33.375000  \n",
       "max    512.329200  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = df.dropna()\n",
    "data.describe()"
   ]
  },
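  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`dropna()` removes every row that contains at least one missing value. Among the columns we kept, only `Age` has missing entries (714 non-null out of 891), which is why `data` has 714 rows. A minimal sketch on hypothetical toy data (`toy` is not part of the notebook) shows one way to count missing values per column before dropping them:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# Hypothetical toy data with one missing Age value\n",
    "toy = pd.DataFrame({\n",
    "    'Age': [22.0, np.nan, 26.0],\n",
    "    'Fare': [7.25, 71.28, 7.93],\n",
    "})\n",
    "\n",
    "# Number of missing values in each column\n",
    "missing_per_column = toy.isna().sum()\n",
    "print(missing_per_column)\n",
    "\n",
    "# Keep only the rows without any missing value\n",
    "complete = toy.dropna()\n",
    "print(len(toy), '->', len(complete))\n",
    "```\n",
    "\n",
    "Dropping rows is the simplest option; filling in missing ages would be an alternative, but it is not what this notebook does."
   ]
  },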
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.axes._subplots.AxesSubplot at 0x1a19b88470>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYcAAAD8CAYAAACcjGjIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAH9JJREFUeJzt3Xu0HGWZ7/HvjyQIIZxgwsUgmYAQEAIhQwIYOGoEncGFx3jhDgqjEnGJspzFeECUiSgLdHSU+zIKh8sMQyAoRmRxGSBMvOWmCSFhgMjFYAIarhMICdn7OX/Uu6Ho6t67d7prd3fy+7BqUV31VtVTnd711Pu+dVFEYGZmlrdVqwMwM7P24+RgZmYFTg5mZlbg5GBmZgVODmZmVuDkYGZmBU4OZmZW4ORgZmYFTg5mZlYwuNUBDJTX1zzelreCH7L/p1odQlVzT9mx1SFUdeGs7VodQk1vj/Y81xqzoS1/+hww9MVWh1DTfn/8pRpdR3+OOUN2fFfD22u29vw1m5lZS20xNQczswHV3dXqCBri5GBmVoauja2OoCFODmZmJYjobnUIDXFyMDMrQ7eTg5mZVXLNwczMCtwhbWZmBa45mJlZpfDVSmZmVtDhHdK+Q9rMrAzRXf/QB0lHSXpE0gpJ51SZP0bSvZIelDRH0m6Nhl9XcpB0nqRlacOLJR3a6IYlfbTaTm7iutY2Yz1mZk3T3VX/0AtJg4ArgA8D+wEnStqvotj3gOsjYjxwAXBRo+H32awkaTLwEeCgiFgvaUdg63pWLmlwRFRteIuI2cDs/gRrZtYxmtchfQiwIiIeB5B0EzAVWJ4rsx/wlTR+P3Bboxutp+YwClgTEesBImJNRKyS9GRKFEiaJGlOGp8uaYaku4HrJc2TNK5nZanKM1HSaZIulzQ8rWurNH+opJWShkjaU9KdkhZJmivp3anMHpJ+K2mBpG81+iWYmTVd18a6B0nTJC3MDdNya3onsDL3+ek0LW8J8Mk0/nFge0kjGwm/nuRwNzBa0qOSrpT0/jqWmQhMjYiTgJuA4wAkjQJ2jYhFPQUj4iWyHetZ7/8B7oqI14EZwJciYiJwNnBlKnMJcFVEHAw8U0c8ZmYDq7u77iEiZkTEpNwwI7emao/zrnwc+NnA+yX9gexY+megocul+kwOEbGW7GA/DfgrMFPSaX0sNjsi1qXxm4Fj0/hxwC1Vys8Ejk/jJ6RtDAMOA26RtBj4EVktBuBw4D/S+A21gshn459c/x+1ipmZNV1EV91DH54GRuc+7waseuu2YlVEfCIi/hY4L017qZH467qUNbLo5wBzJC0FTiXLSj3JZZuKRV7JLftnSc9JGk+WAD5fZROzgYskjSBLRPcB2wEvRsSEWmHVEfcMstpH277sx8w2U83rc1gAjJW0B1mN4ATgpHyB1MT/fGRP+zsXuKbRjfZZc5C0j6SxuUkTgKeAJ8kO5PBmW1ctNwFfBYZHxNLKmal2Mp+suej2iOiKiJeBJyQdm+KQpAPTIr8m+4IATu5rH8zMBlw/mpV6ky7qORO4C3gYuDkilkm6QNJHU7EpwCOSHgV2AS5sNPx6ag7DgMsk7UBWW1hB1sS0L3C1pK8B8/pYxyyyA39vncczyZqcpuSmnQxcJenrwBCyJLMEOAu4UdJZwK117IOZ2cBq4uMzIuIO4I6KaefnxmeRHWebps/kkDqPD6syay6wd5Xy06tMe7ZyWxFxLXBt7vMsKjpeIuIJ4Kgq63sCmJybdHHtPTAza4Gu11sdQUP8+AwzszJ0+OMznBzMzMrgp7KamVmBaw5mZlbg5GBmZpXCHdJmZlbgPgczMytws5KZmRW45mBmZgWuOZiZWYFrDp3hkP0/1eoQqpr/UM0njrfUFyZ9tdUhVDX/tSdaHUJNZw/as9UhVPXc4GqvA2i98SsXtzqEmhp6EcIbK2nKWlpmi0kOZmYDyjUHMzMr
cJ+DmZkVuOZgZmYFrjmYmVmBaw5mZlbgq5XMzKwgotURNMTJwcysDO5zMDOzAicHMzMrcIe0mZkVdHW1OoKGlJIcJHUBS9P6HwZOjYhXa5SdDqyNiO+VEYuZWUt0eLPSViWtd11ETIiI/YENwBklbcfMrD11d9c/tKGykkPeXGAvAEmflvSgpCWSCo8jlXS6pAVp/q2Shqbpx0p6KE3/rzRtnKT5khandY4dgH0xM6tPdNc/tKFS+xwkDQY+DNwpaRxwHnB4RKyRNKLKIj+NiB+nZb8NfBa4DDgf+PuI+LOkHVLZM4BLIuLfJW0NDCpzX8zM+iO6O/s+h7JqDttKWgwsBP4EXA0cAcyKiDUAEfF8leX2lzRX0lLgZGBcmv5r4FpJp/NmEvgt8DVJ/xcYExHrKlcmaZqkhZIWrnn1mWbun5lZ79ysVFVPn8OEiPhSRGwABPSVSq8FzoyIA4BvAtsARMQZwNeB0cBiSSMj4kbgo8A64C5JR1SuLCJmRMSkiJi049B3NG3nzMz61NVV/9CGBqLPoce9wHGSRgLUaFbaHlgtaQhZzYFUds+ImBcR5wNrgNGS3gU8HhGXArOB8aXvgZlZvTq85jBg9zlExDJJFwIPpEtd/wCcVlHsG8A84CmyS2G3T9P/JXU4iyzJLAHOAU6R9DrwDHBB6TthZlavNj3o16uU5BARw2pMvw64rmLa9Nz4VcBVVZb7RJXVXZQGM7P24wfvmZlZQYfXHAayz8HMbMvRHfUPfZB0lKRHJK2QdE6NMsdJWi5pmaQbGw3fNQczszI06SokSYOAK4APAU8DCyTNjojluTJjgXPJ7iN7QdLOjW7XycHMrATRvGalQ4AVEfE4gKSbgKnA8lyZ04ErIuIFgIj4S6MbdbOSmVkZ+tGslL9hNw3Tcmt6J7Ay9/npNC1vb2BvSb+W9DtJRzUavmsOZmZl6MczkyJiBjCjxmxVW6Ti82BgLDAF2A2YK2n/iHix7iAquOZgZlaG5nVIP032dIgeuwGrqpT5eUS8HhFPAI+QJYtN5uRgZlaGjV31D71bAIyVtEd6yOgJZE+FyLsN+ACApB3JmpkebyR8NyuZmZWhSY/ijoiNks4E7iJ78Og16YkTFwALI2J2mvd3kpYDXcA/RcRzjWzXycHMrAxNfGR3RNwB3FEx7fzceAD/mIam2GKSw9xTdmx1CFV9YdJXWx1CVVct/G6rQ6jq3EnntTqEmoaub8/HJYwf/D+tDqGq+0Yc1uoQStXES1lbYotJDmZmA6rDX/bj5GBmVgYnBzMzK2jTl/jUy8nBzKwEnf4OaScHM7MyODmYmVmBr1YyM7MC1xzMzKzAycHMzCpFl5uVzMyskmsOZmZWqdMvZW2bR3ZLOi+9GPtBSYslHdrqmMzMNlnz3ufQEm1Rc5A0GfgIcFBErE/PI9+6xWGZmW26zu5yaJuawyhgTUSsB4iINRGxStJESQ9IWiTpLkmjJA2WtEDSFABJF0m6sJXBm5lVio3ddQ/tqF2Sw93AaEmPSrpS0vslDQEuA46JiInANcCFEbEROA24StKHgKOAb7YqcDOzqrr7MbShtmhWioi1kiYC7yV71d1M4NvA/sA9kiB7A9LqVH6ZpBuAXwCTI2JDtfVKmgZMA7jkQ+P5zPjdS94TM7NMp3dIt0VyAIiILmAOMEfSUuCLwLKImFxjkQOAF4FdelnnDGAGwNqzp3b2v5SZdZY2rRHUqy2alSTtI2lsbtIE4GFgp9RZjaQhksal8U8AI4H3AZdK2mGgYzYz6010R91DO2qXmsMw4LJ0kN8IrCBrDppBdvAfThbrDyU9C1wMHBkRKyVdDlwCnNqa0M3MqujwmkNbJIeIWARUe6HsGrLaQaW9c8teWlZcZmabKja2OoLGtEVyMDPb3IRrDmZmVuDkYGZmlVxzMDOzAicHMzMriC61OoSGODmYmZXANQczMyuIbtcczMysgmsOZmZWEOGag5mZVXDNoUNcOGu7VodQ
1fzXnmh1CFWdO+m8VodQ1UUL2/e9TkcceHqrQ6hqDMNbHUJVU7V5v+yx21crmZlZpU7vkG6LR3abmW1uolt1D32RdJSkRyStkHROlflnSFoqabGkX0nar9H4nRzMzEoQUf/QG0mDgCuADwP7ASdWOfjfGBEHRMQE4LvAvzYav5uVzMxK0MRmpUOAFRHxOICkm4CpwPI3thXxcq78dkDDbxBycjAzK0F/LmXNv+8+mZFecwzwTmBlbt7TwKFV1vFF4B+BrYEj+htvJScHM7MSdPXjaqX8++6rqLaiQs0gIq4ArpB0EvB1Gnw7ppODmVkJmngT3NPA6Nzn3YBVvZS/Cbiq0Y26Q9rMrARNvFppATBW0h6StgZOAGbnC0gam/t4NPBYo/G75mBmVoK+rkKqfz2xUdKZwF3AIOCaiFgm6QJgYUTMBs6U9EHgdeAFGmxSAicHM7NSNPMmuIi4A7ijYtr5ufGzmraxxMnBzKwEXd2d3WrfNtFL+rikkPTuVsdiZtaoZt0E1yptkxyAE4FfkXW2mJl1tO5Q3UM7aovkIGkYcDjwWVJykLSVpCslLZN0u6Q7JB2T5k2U9ICkRZLukjSqheGbmRVEqO6hHbVLn8PHgDsj4lFJz0s6CHgXsDtwALAz8DBwjaQhwGXA1Ij4q6TjgQuBz7QmdDOzonZtLqpXuySHE4EfpvGb0uchwC0R0Q08I+n+NH8fYH/gHkmQXdq1utpK87ekHzXiYCZsv1dpO2BmlteuzUX1anlykDSS7Dkg+0sKsoN9AD+rtQiwLCIm97Xu/C3p5+5+UofncTPrJL5aqXHHANdHxJiI2D0iRgNPAGuAT6a+h12AKan8I8BOkiYDSBoiaVwrAjczqyX6MbSjltccyJqQLq6YdiuwL9kzRR4CHgXmAS9FxIbUMX2ppOFk+/BDYNnAhWxm1js3KzUoIqZUmXYpZFcxRcTa1PQ0H1ia5i8G3jeQcZqZ9Ue7XoVUr5Ynhz7cLmkHsueTfysinml1QGZm9ehudQANauvkUK1WYWbWCaLqaxg6R1snBzOzTrXRzUpmZlbJNQczMytwn4OZmRW45mBmZgWuOZiZWUGXaw5mZlapiW8JbQknBzOzEnS75tAZ3h7t8IzBorMH7dnqEKoaur49Hwd2xIGntzqEmu5b8uNWh1DV88f+Q6tDqGrR8s37HV3t+RdUvy0mOZiZDSR3SJuZWUG33KxkZmYVulodQIOcHMzMSuCrlczMrMBXK5mZWYGvVjIzswI3K5mZWYEvZTUzs4Iu1xzMzKySaw5mZlbQ6clhQB44JOk8ScskPShpsaRDJf1E0n5p/toay71H0ry0zMOSpg9EvGZmjQrVP7Sj0msOkiYDHwEOioj1knYEto6Iz9Wx+HXAcRGxRNIgYJ8yYzUza5Zm1hwkHQVcAgwCfhIRF1fMfxtwPTAReA44PiKebGSbA1FzGAWsiYj1ABGxJiJWSZojaVJPIUnfl/R7SfdK2ilN3hlYnZbriojlqex0STdIuk/SY5La91GdZrZF6urH0Jt0YnwF8GFgP+DEnlaXnM8CL0TEXsAPgO80Gv9AJIe7gdGSHpV0paT3VymzHfD7iDgIeAD45zT9B8Ajkn4m6fOStsktMx44GpgMnC9p1xL3wcysX7pV/9CHQ4AVEfF4RGwAbgKmVpSZStbSAjALOFJq7Ml/pSeHiFhLVtWZBvwVmCnptIpi3cDMNP5vwP9Oy14ATCJLMCcBd+aW+XlErIuINcD9ZF/gW0iaJmmhpIXz1j7WvJ0yM+tDdz+G/LEqDdNyq3onsDL3+ek0jWplImIj8BIwspH4B+RqpYjoAuYAcyQtBU7ta5Hcsn8ErpL0Y+CvkkZWlqnxmYiYAcwA+O6YUzr9bnYz6yD96XPIH6uqqFYDqDye1VOmX0qvOUjaR9LY3KQJwFNV4jgmjZ8E/Cote3SuajSWrHnuxfR5qqRtUrKYAiwoIXwzs00S/Rj68DQwOvd5N2BVrTKS
BgPDgecbCH9Aag7DgMsk7QBsBFaQNTHNypV5BRgnaRFZdej4NP1TwA8kvZqWPTkiulK+mA/8Evgb4FsRUfllmZm1TBOfrbQAGCtpD+DPwAlkJ9F5s8laZH5LdqJ9X0Q0VHMoPTlExCLgsCqzpuTKDEuj36hY9oReVv1oREzrZb6ZWcs062U/EbFR0pnAXWSXsl4TEcskXQAsjIjZwNXADZJWkNUYejt21sV3SJuZlaC7iQ/tjog7gDsqpp2fG38NOLZpG6RDk0NETG91DGZmven0x2d0ZHIwM2t3nX55pJODmVkJXHMwM7OCjersuoOTg5lZCTo7NTg5mJmVws1KZmZW0MxLWVvBycHMrASdnRqcHMzMSuFmpQ4xZkN75vHnBrfnOwLHD/6fVodQ1RiGtzqEmp4/9h9aHUJVI275f60OoaoR489udQil6urwusMWkxzMzAaSaw5mZlYQrjmYmVkl1xzMzKzAl7KamVlBZ6cGJwczs1Js7PD04ORgZlYCd0ibmVmBO6TNzKzANQczMytwzcHMzAq6wjUHMzOr0On3OWxV9gYkdUlaLOkhSbdIGtqEdZ4m6fJmxGdmVobox3/tqPTkAKyLiAkRsT+wATij3gUlDSovLDOz8nT3Y2hHA5Ec8uYCewFIuk3SIknLJE3rKSBpraQLJM0DJks6WNJvJC2RNF/S9qnorpLulPSYpO8O8H6YmfWqm6h7aEcDlhwkDQY+DCxNkz4TEROBScCXJY1M07cDHoqIQ4H5wEzgrIg4EPggsC6VmwAcDxwAHC9pdJVtTpO0UNLC/3x1RVm7ZmZW4Galvm0raTGwEPgTcHWa/mVJS4DfAaOBsWl6F3BrGt8HWB0RCwAi4uWI2Jjm3RsRL0XEa8ByYEzlhiNiRkRMiohJHxy6Vxn7ZmZWVVdE3UM7GoirldZFxIT8BElTyGoBkyPiVUlzgG3S7NcioqunKLWfX7U+N96Fr7wyszbSrs1F9RroPocew4EXUmJ4N/CeGuX+m6xv4WAASdun5ikzs7bW6R3SrTrQ3gmcIelB4BGypqWCiNgg6XjgMknbkvU3fHDgwjQz2zTt2pdQr9KTQ0QMqzJtPVnndJ/lU39DZc3i2jT0lPlIo3GamTVTpzcruYnGzKwE0aYdzfVycjAzK0FXh9ccWtUhbWa2WRuom+AkjZB0T7oh+B5Jb69SZky66XhxuvG4zydVODmYmZUgIuoeGnQO2X1fY4F70+dKq4HD0m0FhwLnSNq1t5U6OZiZlWAAH58xFbgujV8HfKyyQERsSBcCAbyNOo79Tg5mZiXoz+Mz8o/6ScO0vrfwhl0iYjVA+v/O1QpJGp1uH1gJfCciVvW2UndIm5mVoD+PxYiIGcCMWvMl/SfwjiqzzuvHNlYC41Nz0m2SZkXEs7XKOzmYmZWgmfc5RETNm38lPStpVESsljQK+Esf61olaRnwXmBWrXJuVjIzK8EA9jnMBk5N46cCP68sIGm39JQJ0tVMh5M9naKmLabmcMDQF1sdQlXjVy5udQhV3TfisFaHUNVUbd3qEGpatHxUq0OoasT4s1sdQlUTH/xeq0Mo1QDeBHcxcLOkz5I9+fpYAEmTgDMi4nPAvsD3JQXZA02/FxFLa60QtqDkYGY2kAbq8RkR8RxwZJXpC4HPpfF7gPH9Wa+Tg5lZCfzgPTMzK+iKdn0Yd32cHMzMSuAH75mZWYEf2W1mZgXuczAzs4JuNyuZmVkl1xzMzKzAVyuZmVmBm5XMzKzAzUpmZlbgmsMmktQF5B/89LGIeLJF4ZiZNZVrDptuXXqfab9IGhQRXWUEZGbWLF0dfphqq/c5SNpd0lxJv0/DYWn6FEn3S7qRVNuQdIqk+ZIWS/qRpEEtDd7MLCci6h7aUSuTw7bpwL5Y0s/StL8AH4qIg4DjgUtz5Q8BzouI/STtm+YfnmofXcDJlRvIv5f15pf/VO7emJnlDODLfkrRbs1KQ4DLJfUc8PfOzZsfEU+k8SOB
icACSQDbUuXVePn3si7f8+j2/Bcws81Su9YI6tVuVyt9BXgWOJCsVvNabt4ruXEB10XEuQMYm5lZ3Tr9aqW26nMAhgOrI6Ib+BRQqx/hXuAYSTsDSBohacwAxWhm1qfox3/tqN1qDlcCt0o6Frift9YW3hARyyV9Hbhb0lbA68AXgacGLFIzs1748RmbKCKGVZn2GG99z+m5afocYE5F2ZnAzPIiNDPbdO5zMDOzgk7vc3ByMDMrgWsOZmZW0K73L9TLycHMrASuOZiZWYGvVjIzswJ3SJuZWYGblczMrKBd73yul5ODmVkJXHMwM7OCTu9zUKdnt1aQNC09DryttGtc0L6xOa7+a9fY2jWuTtVuT2XtFNNaHUAN7RoXtG9sjqv/2jW2do2rIzk5mJlZgZODmZkVODlsmnZt12zXuKB9Y3Nc/deusbVrXB3JHdJmZlbgmoOZmRVsEclB0nmSlkl6UNJiSYc2YZ0flXROk+JbW/G5K8X5kKRbJA3tZdnpks5uRhyNKOM7bgZJH5cUkt7d4jgK34+kn0jaL81fW2O590ial5Z5WNL0JsZU9++sH+s8TdLlzYgvt86eOHuG3Zu5fqtus78JTtJk4CPAQRGxXtKOwNZ1Ljs4IjZWmxcRs4HZzYv0LdZFxIQUw78DZwD/WtK2GtbIdzwATgR+BZwATG9FALW+n4j4XB2LXwccFxFLJA0C9mliaJv8O5M0KCK6mhhLb96Isz8GOMbNzpZQcxgFrImI9QARsSYiVkl6Mv2RImmSpDlpfLqkGZLuBq5PZ23jelYmaY6kiT1nSJKGp3VtleYPlbRS0hBJe0q6U9IiSXN7zl4l7SHpt5IWSPpWH/HPBfZKy306nXkukXRDZUFJp6d1LpF0a8+ZoKRj09nhEkn/laaNkzQ/nYk9KGlsCd/xREkPpP2/S9IoSYNTjFNSHBdJurCBbdckaRhwOPBZsuSApK0kXZnO4m+XdIekY9K8QrxNCqXW9zNH0qRcvN+X9HtJ90raKU3eGVidluuKiOWp7HRJN0i6T9Jjkk5vMMb87+y29B0sk/TGvQOS1kq6QNI8YLKkgyX9Jv2u5kvaPhXdNf3uH5P03QbjqkrS7ulv6vdpOCxNnyLpfkk3AkvTtFNyv/UfpSRrfYmIzXoAhgGLgUeBK4H3p+lPAjum8UnAnDQ+HVgEbJs+fwX4ZhofBTyaxk8DLk/jPwc+kMaPB36Sxu8FxqbxQ4H70vhs4NNp/IvA2oqY16b/D07r/gIwDngkF/OIXLxnp/GRuXV8G/hSGl8KvDON75D+fxlwchrfumd/m/UdA0OA3wA75b6Xa9L4OOBh4EPAH8jOosv4tz8FuDqN/wY4CDgGuIPsxOgdwAtpWs14S/wNzgEmpfHI/Xucn/ttnZ9i/BnweWCb3L/7EmBbYEdgJbBrP+Mq/M4qflvbAg/1/K5SjMflfjOPAwenz/8rree0NH04sA3wFDC6we+vK31/i4GfpWlDc9/FWGBhGp8CvALskT7vC/wCGJI+X0n62/PQ+7DZNytFxFpJE4H3Ah8AZqrvvoLZEbEujd8M3AP8M3AccEuV8jPJDib3k52hXpnOWg8DbpHUU+5t6f+HA59M4zcA36lY37aSFqfxucDVZAeGWRGxJu3X81Xi2F/St4EdyA5Id6XpvwaulXQz8NM07bfAeZJ2A34aEY/V+C76VO07JktO+wP3pP0fxJtnwMtSzecXwOSI2LCp2+7DicAP0/hN6fMQ4JaI6AaekXR/mr9PrXgbVedvsJvsewP4N9K/U0RckJp8/g44Ke3DlFTu5+l3ui7txyHAbf0IrdrvDODLkj6exkeTHXyfIztI35qm7wOsjogFKc6XAdJ3d29EvJQ+LwfGkCWvTVWtWWkIcLmkCSmuvXPz5kfEE2n8SGAisCDFti3wlwZi2WJs9skBsuo42VnaHElLgVOBjbzZrLZNxSKv5Jb9s6TnJI0nSwCfr7KJ2cBFkkaQ/RDvA7YD
Xqzyo35j1b2EXPhjUPbL7uu642uBj0XWPn0a6SASEWco6yA+GlgsaUJE3JiaB44G7pL0uYi4r4/111TlO/4isCwiJtdY5ADgRWCXTd1mbySNBI4gS5hBdrAPsjPwqovQe7wNqfEb7HWR3LJ/BK6S9GPgr2nf3lKmxue+VPudTQE+SJa0X1XW3Nrz9/FavNmG39vvcX1uvItyjjNfAZ4FDiT7O34tN++V3LiA6yLi3BJi2Kxt9n0OkvapaE+fQFbVfZLsQA5vnsXXchPwVWB4RCytnBkRa4H5wCXA7ZG1Db8MPCHp2BSHJB2YFvk1qQ0cOLnOXbkXOK7nwJASUaXtgdWShuTXK2nPiJgXEecDa4DRkt4FPB4Rl5Ilt/F1xlFQ4zt+GNhJWWcsyvpgxqXxTwAjgfcBl0raYVO33YtjgOsjYkxE7B4Ro4EnyPb/k6nvYRfePAt/pFa8jerlN5i3VYoZshrCr9KyR+vNqudYsoPti+nzVEnbpN/EFGBBE8IdDryQEsO7gffUKPffZH0LB6c4t5c0kCebw8lqLt3Ap8iSfzX3AsdI2hmyvxtJYwYoxo62JdQchgGXpQPQRmAF2QO69gWulvQ1YF4f65hFduDvrfN4JlmT05TctJPJzvi+TlYNvomsnfgs4EZJZ/FmNb1XqSnmQuABSV1kbfWnVRT7RtqXp8j6GXo6CP8lHZxE9seyBDgHOEXS68AzwAX1xFFDre94BtnBfzjZb+2Hkp4FLgaOjIiVyi57vIS+z6T768S0nbxbyf7dnyZrS3+U7Pt6KSI2KOuYfku8wLImxFLr+5mVK/MKME7SIuAlsloqZAe+H0h6NS17ckR0pXwxH/gl8DfAtyJiVRNivRM4Q9KDZAnzd9UKpe/r+LRf2wLryGocA+VK4NZ08nU/b60tvCEilqe/v7uVXTTyOlmttjI5WwXfIW1bHEnDUj/ASLID7OER8Uyr4+oPZfc7rI2I77U6Fts8bQk1B7NKt6ez+K3Jzrg7KjGYDQTXHMzMrGCz75A2M7P+c3IwM7MCJwczMytwcjAzswInBzMzK3ByMDOzgv8PrzPDIqJOfNgAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure()\n",
    "sns.heatmap(data.corr())"
   ]
  },
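  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The heatmap displays the pairwise correlations returned by `DataFrame.corr()` (Pearson by default). To read the values for the class variable directly, one option is to select its column from the correlation matrix and sort it. The sketch below uses hypothetical toy data, so the numbers say nothing about the Titanic dataset:\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Hypothetical toy data, for illustration only\n",
    "toy = pd.DataFrame({\n",
    "    'Survived': [0, 1, 1, 0, 1, 0],\n",
    "    'Sex':      [0, 1, 1, 0, 1, 0],\n",
    "    'Pclass':   [3, 1, 2, 3, 1, 2],\n",
    "})\n",
    "\n",
    "# Correlation of every feature with the class variable, sorted ascending\n",
    "corr_with_target = toy.corr()['Survived'].drop('Survived').sort_values()\n",
    "print(corr_with_target)\n",
    "```"
   ]
  },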
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(714, 6) (714, 1)\n"
     ]
    }
   ],
   "source": [
    "input_cols = ['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare']\n",
    "out_cols = ['Survived']\n",
    "\n",
    "X = data[input_cols]\n",
    "y = data[out_cols]\n",
    "\n",
    "# Sanity check: 6 input features and 1 target column\n",
    "print(X.shape, y.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 146,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Pclass     3.00\n",
       "Sex        0.00\n",
       "Age       22.00\n",
       "SibSp      1.00\n",
       "Parch      0.00\n",
       "Fare       7.25\n",
       "Name: 0, dtype: float64"
      ]
     },
     "execution_count": 146,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = data.reset_index(drop=True)\n",
    "df = data.sort_values('Survived')\n",
    "# print(df[df.Survived==0].shape[0])\n",
    "# # print(df[df['Survived']==1].shape[0])\n",
    "# # print(data.shape[0])\n",
    "#data[\"Pclass\"]\n",
    "val = X.iloc[0,:]\n",
    "val"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functions for your Decision Tree learning algorithm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now implement some of the functionality needed for the decision tree learner. Remember that the _class_ variable, i.e. the target for which we learn a tree, is ```Survived```."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 314,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Pclass</th>\n",
       "      <th>Sex</th>\n",
       "      <th>Age</th>\n",
       "      <th>SibSp</th>\n",
       "      <th>Parch</th>\n",
       "      <th>Fare</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>3.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>26.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>7.925</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Pclass  Sex   Age  SibSp  Parch   Fare\n",
       "0     3.0  1.0  26.0    0.0    0.0  7.925"
      ]
     },
     "execution_count": 314,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_right = pd.DataFrame([], columns=X.columns)\n",
    "val_row = X.iloc[2,:]\n",
    "# x_right = pd.concat([x_right,val_row])\n",
    "\n",
    "d={}\n",
    "for i in input_cols:\n",
    "    d[i]=val_row[i]\n",
    "x_right = x_right.append(d,ignore_index=True)\n",
    "x_right"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 322,
   "metadata": {},
   "outputs": [],
   "source": [
    "def divide_data(x_data, fkey, fval):\n",
    "    # Rows whose fkey value exceeds fval go right; all other rows\n",
    "    # (including missing values) go left. Original row indices are kept.\n",
    "    # Boolean masks replace the row-by-row DataFrame.append loop, since\n",
    "    # DataFrame.append was removed in pandas 2.0.\n",
    "    goes_right = x_data[fkey] > fval\n",
    "    x_right = x_data[goes_right]\n",
    "    x_left = x_data[~goes_right]\n",
    "\n",
    "    # Return the two subsets in an object Series: '0' = left, '1' = right\n",
    "    parts = np.empty(2, dtype=object)\n",
    "    parts[0] = x_left\n",
    "    parts[1] = x_right\n",
    "    return pd.Series(parts, index=['0', '1'])\n",
    "\n",
    "def entropy(data, col):\n",
    "    # Weighted entropy of Survived after partitioning the rows of data\n",
    "    # by each distinct value in col\n",
    "    S = col.shape[0]\n",
    "    values = col.unique()\n",
    "    H = 0\n",
    "    for i in values:\n",
    "        # rows belonging to value i\n",
    "        df1 = data.loc[col == i]\n",
    "        Sv = df1.shape[0]\n",
    "        pos = df1[df1.Survived==1].shape[0]\n",
    "        neg = df1[df1.Survived==0].shape[0]\n",
    "        # a zero count contributes 0 (avoids log2(0))\n",
    "        part1 = -(pos/Sv)*np.log2(pos/Sv) if pos != 0 else 0\n",
    "        part2 = -(neg/Sv)*np.log2(neg/Sv) if neg != 0 else 0\n",
    "        H+=(Sv/S)*(part1 + part2)\n",
    "\n",
    "    return H\n",
    "\n",
    "\n",
    "def information_gain(xdata, fkey, fval):\n",
    "    # Note: fval is currently unused. The gain is measured against a partition\n",
    "    # of the rows by every distinct value of fkey, not a binary split at fval.\n",
    "    total = xdata.shape[0]\n",
    "\n",
    "    # Entropy of Survived before splitting; a zero count contributes 0\n",
    "    initial = 0\n",
    "    for label in (1, 0):\n",
    "        count = xdata[xdata.Survived == label].shape[0]\n",
    "        if count != 0:\n",
    "            initial -= (count/total)*np.log2(count/total)\n",
    "\n",
    "    return initial - entropy(xdata, xdata[fkey])"
   ]
  },
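  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A threshold-based information gain, which scores a binary split of one feature at `fval`, can be sketched in plain Python as below. This is a minimal sketch under stated assumptions rather than the notebook's reference solution: the helper names `binary_entropy` and `threshold_gain` are hypothetical, labels are assumed to be 0/1, and plain lists stand in for DataFrame columns.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def binary_entropy(labels):\n",
    "    # Shannon entropy (in bits) of a list of 0/1 labels; an empty list counts as 0.\n",
    "    if not labels:\n",
    "        return 0.0\n",
    "    p = sum(labels) / len(labels)\n",
    "    return -sum(q * math.log2(q) for q in (p, 1 - p) if q > 0)\n",
    "\n",
    "def threshold_gain(values, labels, fval):\n",
    "    # Gain of splitting into values <= fval (left) and values > fval (right).\n",
    "    left = [y for v, y in zip(values, labels) if v <= fval]\n",
    "    right = [y for v, y in zip(values, labels) if v > fval]\n",
    "    n = len(labels)\n",
    "    weighted = (len(left) / n) * binary_entropy(left) + (len(right) / n) * binary_entropy(right)\n",
    "    return binary_entropy(labels) - weighted\n",
    "\n",
    "# Toy data: the label is 1 exactly when the value exceeds 2.5.\n",
    "values = [1, 2, 3, 4]\n",
    "labels = [0, 0, 1, 1]\n",
    "print(threshold_gain(values, labels, 2.5))  # perfect split: gain equals the full entropy\n",
    "print(threshold_gain(values, labels, 10))   # nothing goes right: gain is zero\n",
    "```\n",
    "\n",
    "Rows with a value at or below `fval` go left and the rest go right, matching the convention in `divide_data`."
   ]
  },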
  {
   "cell_type": "code",
   "execution_count": 323,
   "metadata": {},
   "outputs": [],
   "source": [
    "# d = {\"age\":[1,1,2,3,3,3,2,1,1,3,1,2,2,3],\"buy\":[0,0,1,1,1,0,1,0,1,1,1,1,1,0]}\n",
    "# df = pd.DataFrame(d)\n",
    "# d = {\"age\":[1,1,2,3,3,3,2,1,1,3,1,2,2,4],\"buy\":[0,0,1,1,1,0,1,0,1,1,1,1,1,1]}\n",
    "# df1 = pd.DataFrame(d)\n",
    "# series = pd.Series()\n",
    "# series['0'] = df\n",
    "# series['1'] = df1\n",
    "# for i in range(len(series)):\n",
    "#     sp = series.iloc[i]\n",
    "#     print(sp)\n",
    "#     print()\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 324,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Pclass\n",
      "0.09568028536525675\n",
      "Sex\n",
      "0.21601606075154256\n",
      "Age\n",
      "0.12708102145051403\n",
      "SibSp\n",
      "0.024433098695422872\n",
      "Parch\n",
      "0.03006901595955691\n",
      "Fare\n",
      "0.45887916393629924\n"
     ]
    }
   ],
   "source": [
    "#Here X is your data without the Survived column. Run it after you have filled in the missing code above. \n",
    "\n",
    "for fx in X.columns:\n",
    "    print (fx) \n",
    "    print (information_gain(data, fx, data[fx].mean()))\n",
    "# for fx in df.columns:\n",
    "#     print (fx) \n",
    "#     print (information_gain(df, fx, df[fx].mean()))\n",
    "#     break\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 325,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'a'"
      ]
     },
     "execution_count": 325,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "d = {}\n",
    "l = ['a','b','c','d']\n",
    "for i in range(4):\n",
    "    d[l[i]] = 10\n",
    "\n",
    "maxValKey = max(d, key=d.get)\n",
    "maxValKey"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 374,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DecisionTree:\n",
    "    def __init__(self, depth=0, max_depth=5):\n",
    "        self.left = None\n",
    "        self.right = None\n",
    "        self.fkey = None\n",
    "        self.fval = None\n",
    "        self.max_depth = max_depth\n",
    "        self.depth = depth\n",
    "        self.target = None\n",
    "    \n",
    "    def train(self, train,d,dd):\n",
    "        Tree = dd\n",
    "        #print(train.head(5))\n",
    "        input_cols = ['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare']\n",
    "        out_cols = ['Survived']\n",
    "        col_dict = {'Pclass':0, 'Sex':1, 'Age':2, 'SibSp':3, 'Parch':4, 'Fare':5}\n",
    "        X_train = train[input_cols]\n",
    "        y_train = train[out_cols]\n",
    "        self.depth = d\n",
    "#         print (self.depth, '-'*10)\n",
    "#         print(\"max is \",self.max_depth)\n",
    "        if(self.depth > self.max_depth):\n",
    "            zero = y_train[y_train.Survived==0].shape[0]\n",
    "            one = y_train[y_train.Survived==1].shape[0]\n",
    "            if(zero > one):\n",
    "                return 0\n",
    "            else:\n",
    "                return 1\n",
    "\n",
    "\n",
    "        \n",
    "        # Get the best possible feature and division value (gains)\n",
    "        #TODO\n",
    "        # store the best feature (using min information gain)\n",
    "        #TODO\n",
    "        gains = {}\n",
    "        for fx in X_train.columns:\n",
    "            g = information_gain(train, fx, train[fx].mean())\n",
    "            gains[fx] = g\n",
    "        bestFeature = max(gains, key=gains.get)\n",
    "        # divide the dataset and reset index \n",
    "        #TODO\n",
    "        split_series = divide_data(train, bestFeature, train[bestFeature].mean())\n",
    "        #print(split_series)\n",
    "        Tree = {bestFeature:{}}\n",
    "        \n",
    "        for j in range(len(split_series)):\n",
    "            split_data = split_series.iloc[j]\n",
    "            if(j == 0):\n",
    "                value = '<= '+str(train[bestFeature].mean())\n",
    "            else:\n",
    "                value = '>'+str(train[bestFeature].mean())\n",
    "            self.depth+=1\n",
    "            Tree[bestFeature][value] = self.train(split_data,self.depth,Tree) \n",
    "            \n",
    "\n",
    " \n",
    "            \n",
    "            \n",
    "            \n",
    "            # Check the shapes and depth if it has exceeded max_depth or not in case it has make predictions \n",
    "            # TODO\n",
    "\n",
    "            # branch to right\n",
    "            #TODO\n",
    "\n",
    "            # branch to left\n",
    "            #TODO\n",
    "\n",
    "            #Make your prediction \n",
    "            #TODO\n",
    "        \n",
    "        return Tree\n",
    "    \n",
    "    def predict(self,rules,data):\n",
    "        if(isinstance(rules,int)):\n",
    "            return rules \n",
    "\n",
    "        for i in rules.keys():\n",
    "            if(i in data):\n",
    "                val = data[i]\n",
    "                for j in rules[i].keys():\n",
    "                    jj = str(val) + j\n",
    "                    if(eval(jj)):\n",
    "                        result = pred(rules[i][j],data)\n",
    "        return result\n",
    "    \n",
    "#     def predict(self, test):\n",
    "#         if test[self.fkey] > self.fval:\n",
    "#             pass\n",
    "#             # go right\n",
    "#             #TODO\n",
    "#         else:\n",
    "#             pass\n",
    "#             # go left\n",
    "#             #TODO"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Divide your data: separate Training and Test sets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 375,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(571, 7)\n",
      "(143, 7)\n"
     ]
    }
   ],
   "source": [
    "# Train:Test split of 0.8:0.2\n",
    "slices = int(data.shape[0]*0.8)\n",
    "training_data = data.head(slices)\n",
    "Testing_data = data.loc[slices:data.shape[0],:]\n",
    "print(training_data.shape)\n",
    "print(Testing_data.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train your own decision tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 376,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "dt = DecisionTree()\n",
    "rules = dt.train(training_data,0,{})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 377,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\\nprint (dt.fkey, dt.fval)\\nprint (dt.right.fkey, dt.right.fval)\\nprint (dt.left.fkey, dt.left.fval)\\n\\nprint (dt.right.right.fkey, dt.right.right.fval)\\nprint (dt.right.left.fkey, dt.right.left.fval)\\n\\n\\nprint (dt.left.right.fkey, dt.left.right.fval)\\nprint (dt.left.left.fkey, dt.left.left.fval)\\n'"
      ]
     },
     "execution_count": 377,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "'''\n",
    "print (dt.fkey, dt.fval)\n",
    "print (dt.right.fkey, dt.right.fval)\n",
    "print (dt.left.fkey, dt.left.fval)\n",
    "\n",
    "print (dt.right.right.fkey, dt.right.right.fval)\n",
    "print (dt.right.left.fkey, dt.right.left.fval)\n",
    "\n",
    "\n",
    "print (dt.left.right.fkey, dt.left.right.fval)\n",
    "print (dt.left.left.fkey, dt.left.left.fval)\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 411,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'Fare': {'<= 35.444439229422024': {'Fare': {'<= 14.965582435597199': {'Fare': {'<= 9.233348507462688': {'Fare': {'<= 7.538276506024092': {'Age': {'<= 29.431818181818183': {'Fare': {'<= 6.505866666666667': 0, '>6.505866666666667': 0}}, '>29.431818181818183': 0}}, '>7.538276506024092': 0}}, '>9.233348507462688': 0}}, '>14.965582435597199': 0}}, '>35.444439229422024': 1}}\n",
      "Survived     0.000\n",
      "Pclass       3.000\n",
      "Sex          0.000\n",
      "Age         33.000\n",
      "SibSp        0.000\n",
      "Parch        0.000\n",
      "Fare         7.775\n",
      "Name: 571, dtype: float64\n",
      "0\n"
     ]
    }
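   ],
   "source": [
    "def pred(rules,data):\n",
    "    # A leaf is stored as a plain int class label\n",
    "    if(isinstance(rules,int)):\n",
    "        return rules \n",
    "\n",
    "    result = None\n",
    "    for i in rules.keys():\n",
    "        if(i in data):\n",
    "            val = data[i]\n",
    "            for j in rules[i].keys():\n",
    "                # Branch keys look like '<= 35.4' or '>35.4'\n",
    "                matches = val <= float(j[2:]) if j.startswith('<=') else val > float(j[1:])\n",
    "                if(matches):\n",
    "                    result = pred(rules[i][j],data)\n",
    "    return result\n",
    "\n",
    "# Use a new name so the full DataFrame in `data` is not overwritten\n",
    "sample = Testing_data.iloc[0,:]\n",
    "print(rules)\n",
    "print(sample)\n",
    "print(pred(rules,sample))"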
   ]
  },
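  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The learned rules encode each branch as a string such as `'<= 35.44'`, which has to be interpreted again at prediction time. One alternative, offered here as an assumption and not as the notebook's intended design, is to store the feature name and numeric threshold directly in each node. The helper names `make_node` and `predict_one` and the toy tree are hypothetical.\n",
    "\n",
    "```python\n",
    "def make_node(feature, threshold, left, right):\n",
    "    # Internal node: left handles value <= threshold, right handles value > threshold.\n",
    "    return {'feature': feature, 'threshold': threshold, 'left': left, 'right': right}\n",
    "\n",
    "def predict_one(node, row):\n",
    "    # Walk down until a leaf (a plain int class label) is reached.\n",
    "    while isinstance(node, dict):\n",
    "        branch = 'left' if row[node['feature']] <= node['threshold'] else 'right'\n",
    "        node = node[branch]\n",
    "    return node\n",
    "\n",
    "# Hypothetical two-level tree, not the one learned in this notebook.\n",
    "toy_tree = make_node('Fare', 35.0, make_node('Age', 10.0, 1, 0), 1)\n",
    "print(predict_one(toy_tree, {'Fare': 7.775, 'Age': 33.0}))  # 0\n",
    "print(predict_one(toy_tree, {'Fare': 53.1, 'Age': 33.0}))   # 1\n",
    "```\n",
    "\n",
    "Because thresholds stay numeric, no string parsing or `eval` is needed during prediction."
   ]
  },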
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Make predictions for the first 10 test rows and check whether they are correct."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 359,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7.775\n",
      "33.0\n",
      "7.0542\n",
      "13.0\n",
      "13.0\n",
      "53.1\n",
      "8.6625\n",
      "21.0\n",
      "26.0\n",
      "7.925\n"
     ]
    }
   ],
   "source": [
    "for ix in Testing_data.index[:10]:\n",
    "    print(Testing_data.loc[ix]['Fare'])\n",
    "    print (dt.predict(rules, Testing_data.loc[ix]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "Testing_data.head(10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now check how many predictions are correct over the entire test set: aim for at least 75 percent accuracy!"
   ]
  },
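  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Accuracy is the fraction of test rows whose predicted label equals the true label. A minimal sketch follows; the helper name `accuracy` is hypothetical and the labels are made up for illustration, so they are not results from the Titanic data.\n",
    "\n",
    "```python\n",
    "def accuracy(predictions, targets):\n",
    "    # Fraction of positions where the prediction equals the true label.\n",
    "    if len(predictions) != len(targets):\n",
    "        raise ValueError('predictions and targets must have the same length')\n",
    "    if not targets:\n",
    "        return 0.0\n",
    "    correct = sum(1 for p, t in zip(predictions, targets) if p == t)\n",
    "    return correct / len(targets)\n",
    "\n",
    "# Made-up labels for illustration only.\n",
    "y_true = [0, 1, 1, 0, 1]\n",
    "y_pred = [0, 1, 0, 0, 1]\n",
    "print(accuracy(y_pred, y_true))  # 0.8\n",
    "```\n",
    "\n",
    "With the objects in this notebook, the predictions could be collected as `[dt.predict(rules, Testing_data.loc[ix]) for ix in Testing_data.index]` and compared against `list(Testing_data['Survived'])`."
   ]
  },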
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Now use scikit-learn: Decision Tree and Random Forests"
   ]
  },
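  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a starting point for the comparison, the sketch below fits scikit-learn's `DecisionTreeClassifier` and `RandomForestClassifier` and reports held-out accuracy. It is a sketch under stated assumptions: the data is synthetic (generated with `make_classification`, not the Titanic set) so that the example is self-contained, and the hyperparameters are arbitrary choices rather than tuned values.\n",
    "\n",
    "```python\n",
    "from sklearn.datasets import make_classification\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "# Synthetic stand-in data with 6 features, NOT the Titanic dataset.\n",
    "X_demo, y_demo = make_classification(n_samples=300, n_features=6, random_state=0)\n",
    "X_tr, X_te, y_tr, y_te = train_test_split(X_demo, y_demo, test_size=0.2, random_state=0)\n",
    "\n",
    "tree = DecisionTreeClassifier(criterion='entropy', max_depth=5, random_state=0)\n",
    "tree.fit(X_tr, y_tr)\n",
    "\n",
    "forest = RandomForestClassifier(n_estimators=100, max_depth=5, random_state=0)\n",
    "forest.fit(X_tr, y_tr)\n",
    "\n",
    "# score() returns the mean accuracy on the held-out rows\n",
    "tree_acc = tree.score(X_te, y_te)\n",
    "forest_acc = forest.score(X_te, y_te)\n",
    "print('decision tree accuracy:', tree_acc)\n",
    "print('random forest accuracy:', forest_acc)\n",
    "```\n",
    "\n",
    "To run this on the notebook's data, the synthetic arrays could be swapped for `training_data[input_cols]` and `training_data['Survived']`, with the matching `Testing_data` columns for scoring."
   ]
  },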
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
